# encoding: utf-8

import torch
from torch import Tensor

# Sample data: ten int32 values, consumed by the demo call at the bottom of the file.
tensor = torch.tensor(
    [5, 2, 6, 7, 9, 3, 1, 3, 5, 4],
    dtype=torch.int32,
)
# Positions 0..9 as an int32 tensor (one entry per element of ``tensor``).
# NOTE(review): not referenced in the visible code; possibly unused.
_index = torch.arange(10, dtype=torch.int32)


def arg_max_many(ts: Tensor, count: int = 2) -> list:
    """Return the indices of the ``count`` largest elements of a 1-D tensor.

    Indices are ordered from the largest value to the smallest. When several
    elements share a value, the one appearing first in ``ts`` is listed first,
    matching the "first maximal index" rule of ``torch.max``.

    Args:
        ts: One-dimensional tensor to search. It is not modified.
        count: Number of indices to return. A value ``<= 0`` yields ``[]``.

    Returns:
        A list of ``count`` Python ints indexing into ``ts``.

    Raises:
        ValueError: If ``ts`` is not 1-D, or has fewer than ``count`` elements.
    """
    if count <= 0:
        return []
    if ts.dim() != 1:
        raise ValueError(f"expected a 1-D tensor, got shape {tuple(ts.shape)}")
    if count > ts.numel():
        raise ValueError(
            f"count={count} exceeds the number of elements ({ts.numel()})"
        )
    # A stable descending sort ranks every element without overwriting values
    # with a sentinel. The previous -999999 trick returned wrong or duplicate
    # indices for tensors containing values <= the sentinel (e.g. -inf) and
    # when count exceeded the tensor length. Stability keeps equal values in
    # their original order, so ties resolve to the earliest index.
    _, order = torch.sort(ts, descending=True, stable=True)
    return order[:count].tolist()


# Demo: print the positions of the three largest entries of ``tensor``.
# NOTE(review): this runs on import as well; consider an
# ``if __name__ == "__main__":`` guard if the module is ever imported.
result = arg_max_many(tensor, count=3)
print(result)
